{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Code based on the article: https://arxiv.org/abs/1903.10176\n",
    "## Important Note:\n",
    "Our algorithm integrates BM3D as RED engine with DIP network.\n",
    "\n",
    "Follow these 3 steps:\n",
    "    1. Download the BM3D code: http://www.cs.tut.fi/~foi/GCF-BM3D/ and put it in the \"matlab_codes\" folder\n",
    "    2. Download and install the Matlab engine, see:\n",
    "        https://www.mathworks.com/help/matlab/matlab_external/install-matlab-engine-api-for-python-in-nondefault-locations.html\n",
    "    3. Edit the BM3D Matlab code (and the wrapper function), so you will be able to use it as an external denoiser.\n",
    "        We are saving arrays as \".npy\" and reading them in Matlab code.\n",
    "        We found that it's the fastest way to work. Working this way requires the code:\n",
    "        https://github.com/kwikteam/npy-matlab\n",
    "        just put all *.m files in the matlab_codes folder. You may work with \".mat\" instead (it's a bit slower, though)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from threading import Thread  # needed since the denoiser is running in parallel\n",
    "import queue\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim\n",
    "from models.skip import skip  # our network\n",
    "\n",
    "from utils.utils import *  # auxiliary functions\n",
    "from utils.data import Data  # class that holds img, psnr, time\n",
    "from models.downsampler import Downsampler\n",
    "\n",
    "from skimage.restoration import denoise_nl_means, estimate_sigma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Matlab:\n",
    "from utils.matlab_utils import *\n",
    "os.environ['TMP'] = r'tmp'\n",
    "os.environ['TEMP'] = r'tmp'\n",
    "prevdir = os.getcwd()\n",
    "os.chdir(\"matlab_codes/\")  # the engine is started from here; make sure all the matlab codes are in this directory\n",
    "try:\n",
    "    eng = matlab.engine.start_matlab(\"-nojvm\")  # this may take a while. To close session use: eng.quit()\n",
    "finally:\n",
    "    # always restore the working directory, even when the engine fails to start,\n",
    "    # otherwise every relative path used later in the notebook would be resolved inside matlab_codes/\n",
    "    os.chdir(prevdir)\n",
    "print(\"Successfully loaded\", eng)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CUDA_FLAG = False  # got GPU?\n",
    "CUDNN = False  # if you are not getting the exact article results set CUDNN to False\n",
    "if CUDA_FLAG:\n",
    "    os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "    # GPU accelerated functionality for common operations in deep neural nets\n",
    "    torch.backends.cudnn.enabled = CUDNN\n",
    "    # benchmark mode is good whenever your input sizes for your network do not vary.\n",
    "    # This way, cudnn will look for the optimal set of algorithms for that particular \n",
    "    # configuration (which takes some time). This usually leads to faster runtime.\n",
    "    # But if your input sizes changes at each iteration, then cudnn will benchmark every\n",
    "    # time a new size appears, possibly leading to worse runtime performances.\n",
    "    torch.backends.cudnn.benchmark = CUDNN\n",
    "    # torch.backends.cudnn.deterministic = True\n",
    "    dtype = torch.cuda.FloatTensor\n",
    "else:\n",
    "    dtype = torch.FloatTensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FACTOR = 4  # scale factor, factor > 1, org HR -> LR --train--> est HR, compare psnr with org HR vs est HR on Y. ch\n",
    "            #               factor < 1, LR --train--> est HR, in that case ignore the psnrs, and enjoy your image\n",
    "# Algorithm NAMES (to get the relevant image: use data_dict[alg_name].img)\n",
    "# for example use data_dict['Clean'].img to get the clean image\n",
    "ORIGINAL  = 'HR'\n",
    "CORRUPTED = 'LR'\n",
    "BICUBIC   = 'Bicubic'\n",
    "DIP_BM3D  = 'DRED (BM3D)' "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load image for Super-Resolution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_LR_HR_imgs_sr(fname, factor, use_matlab_resize=False, plot=False):  # matlab_resize requires matlab engine\n",
    "    \"\"\"  Loads an image, resize it, center crops and upscale / downscales depends on factor (i.e > 1 or < 1)\n",
    "    Args: \n",
    "         fname: path to the image\n",
    "         factor: downscaling factor, you can pass factors smaller than 1 (i.e 0.25) it will scale up\n",
    "    Out:\n",
    "         dictionary of 3 kinds of images: 'LR', 'HR', 'Bicubic', # you may add: 'Sharp', 'Nearest'\n",
    "    \"\"\"\n",
    "    if factor > 1:  # we have the original (HR) image, lets get the LR image\n",
    "        img_HR_pil, img_HR_np = load_and_crop_image(fname, d=factor)  # crop by factor\n",
    "        if use_matlab_resize:\n",
    "            img_LR_pil, img_LR_np = matlab_resize(eng, img_HR_np, 1. / factor)\n",
    "        else:\n",
    "            img_LR_pil, img_LR_np = pil_resize(img_HR_pil, factor)\n",
    "    else:  # we have the small image, lets get the HR image\n",
    "        img_LR_pil, img_LR_np = load_and_crop_image(fname)\n",
    "        if use_matlab_resize:\n",
    "            img_HR_pil, img_HR_np = matlab_resize(eng, img_LR_np, 1. / factor)\n",
    "        else:\n",
    "            img_HR_pil, img_HR_np = pil_resize(img_LR_pil, int(1 // factor), downscale=False)\n",
    "    # Gets `bicubic` baseline\n",
    "    img_bicubic_pil = img_LR_pil.resize(img_HR_pil.size, Image.BICUBIC)\n",
    "    img_bicubic_np = pil_to_np(img_bicubic_pil)\n",
    "    data_dict = {ORIGINAL: Data(img_HR_np), CORRUPTED: Data(img_LR_np),\n",
    "                 BICUBIC: Data(img_bicubic_np, compare_psnr_y(img_HR_pil, img_bicubic_pil))\n",
    "                 }\n",
    "    if plot:\n",
    "        print('HR and LR resolutions:', img_HR_np.shape, img_LR_np.shape)\n",
    "        plot_dict(data_dict)\n",
    "    return data_dict\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the LR and HR images\n",
    "data_dict = load_LR_HR_imgs_sr('datasets/Color Set14/flowers.png', factor=FACTOR, plot=True)\n",
    "if FACTOR < 1:\n",
    "    FACTOR = int(1 // FACTOR)  # after loaded, we fix the factor so other algorithm will work"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# THE NETWORK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_network_and_input(img_shape, input_depth=32, pad='reflection',\n",
    "                          upsample_mode='bilinear', use_interpolate=True, align_corners=False,\n",
    "                          act_fun='LeakyReLU', skip_n33d=128, skip_n33u=128, skip_n11=4,\n",
    "                          num_scales=5, downsample_mode='stride', INPUT='noise'):  # 'meshgrid'\n",
    "    \"\"\" Builds the skip network and its input tensor for an image of the given shape.\n",
    "        The default params are the same as in the DIP article.\n",
    "        img_shape - the image shape (ch, x, y): `ch` is the number of output channels of the network,\n",
    "                    and (x, y) is passed to get_noise as the spatial size of the input\n",
    "        returns (net, net_input), both converted to the global `dtype`; net_input is detached\n",
    "    \"\"\"\n",
    "    def per_scale(channels):\n",
    "        # a single int is repeated once per scale; anything else is passed through as given\n",
    "        return [channels] * num_scales if isinstance(channels, int) else channels\n",
    "\n",
    "    network = skip(input_depth, img_shape[0],\n",
    "                   num_channels_down=per_scale(skip_n33d),\n",
    "                   num_channels_up=per_scale(skip_n33u),\n",
    "                   num_channels_skip=per_scale(skip_n11),\n",
    "                   upsample_mode=upsample_mode, use_interpolate=use_interpolate, align_corners=align_corners,\n",
    "                   downsample_mode=downsample_mode, need_sigmoid=True, need_bias=True, pad=pad, act_fun=act_fun)\n",
    "    network = network.type(dtype)\n",
    "    spatial_size = img_shape[1:]\n",
    "    network_input = get_noise(input_depth, INPUT, spatial_size).type(dtype).detach()\n",
    "    return network, network_input"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BM3D denoiser (This wrapper function for the original Matlab code)\n",
    "#### make sure you edit the original code to work with this function / or alternatively create your own wrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bm3d(noisy_np_img, sigma):\n",
    "    \"\"\" Denoises an image with the Matlab BM3D code, exchanging the data through .npy files.\n",
    "        noisy_np_img - numpy image, channels first (ch, x, y); it is saved channels last\n",
    "        sigma        - passed to the Matlab denoiser as a float\n",
    "        returns the denoised image as float32, channels first\n",
    "    \"\"\"\n",
    "    np.save('matlab_codes/noisy.npy', noisy_np_img.transpose(1, 2, 0))\n",
    "    is_color = noisy_np_img.shape[0] == 3\n",
    "    if is_color:  # Color BM3D (3D)\n",
    "        eng.CBM3D_denoise(float(sigma), nargout=0)\n",
    "    else:  # 2D BM3D\n",
    "        eng.BM3D_denoise(float(sigma), nargout=0)\n",
    "    denoised = np.load('matlab_codes/denoised.npy').astype(np.float32)\n",
    "    if is_color:\n",
    "        return denoised.transpose(2, 0, 1)  # back to channels first\n",
    "    return np.expand_dims(denoised, axis=0)  # add the channel axis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep Learning Powered by RED, Our Generic Function\n",
    "## The RED engine with Neural Network\n",
    "## you may test it with any neural net, and any denoiser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_via_admm(net, net_input, denoiser_function, D, y, bicubic, HR_img=None,  # D is the downsampler, y is LR image\n",
    "                   plot_array={}, algorithm_name=\"\", save_path=\"\",          # will save params and graphs in this folder\n",
    "                   admm_iter=2000, LR=0.001, update_iter=10, method='fixed_point',   # 'fixed_point' or 'grad' or 'mixed'\n",
    "                   sigma_f=5, beta=0.05, mu=0.06, LR_x=None, noise_factor=.02):    # LR_x needed only if method!=fixed_point\n",
    "    \"\"\" training the network using\n",
    "        ## Must Params ##\n",
    "        net                 - the network to be trained\n",
    "        net_input           - the network input\n",
    "        denoiser_function   - an external denoiser function, used as black box, this function\n",
    "                              must get numpy noisy image, and return numpy denoised image\n",
    "        D                   - the downsampler\n",
    "        y                   - LR image\n",
    "        bicubic             - bicubic of y (High Res. image\n",
    "        \n",
    "        # optional params #\n",
    "        HR_img              - the original image if exist for psnr compare only, or None (default)\n",
    "        plot_array          - prints params at the begging of the training and plot images at the required indices\n",
    "        algorithm_name      - the name that would show up while running, just to know what we are running ;)\n",
    "        admm_iter           - total number of admm epoch\n",
    "        LR                  - the learning rate of the network\n",
    "        sigma_f             - the sigma to send the denoiser function\n",
    "        update_iter         - denoised image updated every 'update_iter' iteration\n",
    "        method              - 'fixed_point' or 'grad' or 'mixed' \n",
    "                \n",
    "        # equation params #  \n",
    "        beta                - regularization parameter (lambda in the article)\n",
    "        mu                  - ADMM parameter\n",
    "        LR_x                - learning rate of the parameter x, needed only if method!=fixed point\n",
    "        # more\n",
    "        noise_factor       - the starting amount of noise added to the input of the network\n",
    "    \"\"\"\n",
    "    # get optimizer and loss function:\n",
    "    optimizer = torch.optim.Adam(net.parameters(), lr=LR)  # using ADAM opt\n",
    "\n",
    "    mse = torch.nn.MSELoss().type(dtype)  # using MSE loss\n",
    "    # additional noise added to the input:\n",
    "    net_input_saved = net_input.detach().clone()\n",
    "    noise = net_input.detach().clone()\n",
    "    # x update method:\n",
    "    if method == 'fixed_point':\n",
    "        swap_iter = admm_iter + 1\n",
    "        LR_x = None\n",
    "    elif method == 'grad':\n",
    "        swap_iter = -1\n",
    "    elif method == 'mixed':\n",
    "        swap_iter = admm_iter // 2\n",
    "    else:\n",
    "        assert False, \"method can only be 'fixed_point' or 'grad' or 'mixed' !\"\n",
    "        \n",
    "    # initialize:\n",
    "    x = np.zeros_like(bicubic)  # bicubic.copy()\n",
    "    y_torch = np_to_torch(y).type(dtype)\n",
    "    f_x, u = x.copy(), np.zeros_like(x)\n",
    "    img_queue = queue.Queue()\n",
    "    denoiser_thread = Thread(target=lambda q, f, f_args: q.put(f(*f_args)),\n",
    "                             args=(img_queue, denoiser_function, [x.copy(), sigma_f]))\n",
    "    denoiser_thread.start()\n",
    "    for i in range(1, 1 + admm_iter):\n",
    "        # step 1, update network, eq. 7 in the article\n",
    "        optimizer.zero_grad()\n",
    "        net_input = net_input_saved + (noise.normal_() * noise_factor)\n",
    "        out = net(net_input)\n",
    "        out_np = torch_to_np(out)\n",
    "        # loss:\n",
    "        loss_y = mse(D(out), y_torch)\n",
    "        loss_x = mse(out, np_to_torch(x - u).type(dtype))\n",
    "        total_loss = loss_y + mu * loss_x\n",
    "        total_loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        # step 2, update x using a denoiser and result from step 1\n",
    "        if i % update_iter == 0:  # the denoiser work in parallel\n",
    "            denoiser_thread.join()\n",
    "            f_x = img_queue.get()\n",
    "            denoiser_thread = Thread(target=lambda q, f, f_args: q.put(f(*f_args)),\n",
    "                                     args=(img_queue, denoiser_function, [x.copy(), sigma_f]))\n",
    "            denoiser_thread.start()\n",
    "\n",
    "        # step 2, update x using a the denoiser (f_x) and network outputs (out_np)\n",
    "        if i < swap_iter:\n",
    "            x = 1 / (beta + mu) * (beta * f_x + mu * (out_np + u))  # eq. 11 in the article\n",
    "        else:\n",
    "            x = x - LR_x * (beta * (x - f_x) + mu * (x - out_np - u))  # eq. 12 in the article\n",
    "        np.clip(x, 0, 1, out=x)  # making sure that image is in bounds\n",
    "\n",
    "        # step 3, update u\n",
    "        u = u + out_np - x\n",
    "\n",
    "        # show psnrs:\n",
    "        if HR_img is not None:\n",
    "            psnr_net = compare_PSNR(HR_img, out_np, on_y=True)  # psnr of network output\n",
    "            psnr_x_u = compare_PSNR(HR_img, x - u, on_y=True)  # the psnr of our result\n",
    "            print('\\r', algorithm_name, '%04d/%04d Loss %f' % (i, admm_iter, total_loss.item()),\n",
    "                  'psnr net: %.2f psnr x-u: %.2f' % (psnr_net, psnr_x_u), end='')\n",
    "            if i in plot_array:  # plot images\n",
    "                tmp_dict = {'HR': Data(HR_img),\n",
    "                            'LR': Data(y),\n",
    "                            'Net': Data(out_np, psnr_net),\n",
    "                            'x-u': Data(x - u, psnr_x_u),\n",
    "                            'u': Data((u - np.min(u)) / (np.max(u) - np.min(u)))\n",
    "                            }\n",
    "                plot_dict(tmp_dict)\n",
    "        else:\n",
    "            print('\\r', algorithm_name, 'iteration %04d/%04d Loss %f' % (i, admm_iter, total_loss.item()), end='')\n",
    "    if denoiser_thread.is_alive():\n",
    "        denoiser_thread.join()  # joining the thread\n",
    "    return x - u"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preparations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_and_plot(denoiser, name, plot_checkpoints={}):\n",
    "    \"\"\" Trains a fresh network with `denoiser` on the images of the global data_dict,\n",
    "        stores the result (and its psnr on Y) in data_dict[name], and plots the whole dictionary.\n",
    "        denoiser         - the denoiser function handed to train_via_admm\n",
    "        name             - the algorithm name, also the key of the result in data_dict\n",
    "        plot_checkpoints - the iterations at which train_via_admm plots intermediate images\n",
    "    \"\"\"\n",
    "    global data_dict\n",
    "    lr_img = data_dict[CORRUPTED].img\n",
    "    hr_img = data_dict[ORIGINAL].img\n",
    "    print(\"Ready to go, %s -> %s\" % (lr_img.shape, hr_img.shape))\n",
    "\n",
    "    # the network with its input, and the downsampler that maps the network output to the LR size:\n",
    "    net, net_input = get_network_and_input(img_shape=hr_img.shape)\n",
    "    downsampler = Downsampler(n_planes=lr_img.shape[0], factor=FACTOR,\n",
    "                              kernel_type='lanczos2', phase=0.5, preserve_size=True)\n",
    "    downsampler = downsampler.type(dtype)\n",
    "\n",
    "    # the generic training function:\n",
    "    restored = train_via_admm(net, net_input, denoiser, downsampler, lr_img, data_dict[BICUBIC].img,\n",
    "                              HR_img=hr_img, plot_array=plot_checkpoints, algorithm_name=name)\n",
    "    data_dict[name] = Data(restored, compare_PSNR(hr_img, restored, on_y=True))\n",
    "    plot_dict(data_dict)\n",
    "\n",
    "\n",
    "plot_checkpoints = {1, 10, 100, 1000, 1500}\n",
    "run_and_plot(bm3d, DIP_BM3D, plot_checkpoints)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
